#!/usr/local/bin/python3
# -*- coding: utf-8 -*-

"""
@File    : stdf_parser_file_write_read.py
@Author  : Link
@Time    : 2022/12/24 12:07
@Mark    : 
"""
import os
from typing import Union, List, Dict

import pandas as pd
import numpy as np
from pandas import DataFrame as Df

from app_test.test_utils.wrapper_utils import Time
from common.app_variable import DataModule, GlobalVariable as GloVar, PtmdModule, TestVariable, \
    PartFlags, FailFlag, GlobalVariable, DatatType
from common.data_class_interface.for_analysis_stdf import PRR_HEAD, DTP_HEAD, PTMD_HEAD, BIN_HEAD, STDF_HEAD, DTR_HEAD
from parser_core.stdf_parser_func import PrrPartFlag, DtpTestFlag, PrrFailTestId
from ui_component.ui_common.my_text_browser import Print


class ParserData:
    """Stateless helpers to load, merge and persist STDF parser DataFrames.

    The class is used purely as a namespace: every method is a ``staticmethod``
    that round-trips parser output between CSV/HDF5 files and the in-memory
    ``DataModule`` bundle (prr / dtp / ptmd / bin / log DataFrames).
    """

    @staticmethod
    def delete_temp_file():
        """Delete every temp file listed in ``TestVariable.PATHS`` (if present)."""
        for path in TestVariable.PATHS:
            if os.path.exists(path):
                os.remove(path)

    @staticmethod
    def delete_swap_file(path):
        """Delete the parser swap files (``GlobalVariable.PARSER_FILES``) under *path*."""
        for each in GlobalVariable.PARSER_FILES:
            full_path = os.path.join(path, each)
            if os.path.exists(full_path):
                os.remove(full_path)

    @staticmethod
    def load_csv(path) -> Union[DataModule, None]:
        """
        Load the five parser CSV files under *path* into a ``DataModule``.

        TODO: 93k support should be handled here so the C++ side needs as few
              changes as possible.  93k TEST_TXT uses an ``@`` separator --
              split on the first ``@``; rows whose OPT_FLG is 0x0 are
              back-filled from the first row sharing the same text prefix.
        :param path: directory containing the parser CSV files.
        :return: the loaded ``DataModule``, or ``None`` on any read error.
        """
        try:
            prr_df = pd.read_csv(os.path.join(path, GlobalVariable.PRR_FILE),
                                 header=None, names=PRR_HEAD.PRR_HEAD, dtype=GloVar.PRR_TYPE_DICT)
            dtp_df = pd.read_csv(os.path.join(path, GlobalVariable.DTP_PATH),
                                 header=None, names=DTP_HEAD.DTP_HEAD, dtype=GloVar.DTP_TYPE_DICT)
            ptmd_df = pd.read_csv(os.path.join(path, GlobalVariable.PTMD_PATH),
                                  header=None, names=PTMD_HEAD.PTMD_HEAD, dtype=GloVar.PTMD_TYPE_DICT)
            bin_df = pd.read_csv(os.path.join(path, GlobalVariable.BIN_PATH),
                                 header=None, names=BIN_HEAD.BIN_HEAD, dtype=GloVar.BIN_TYPE_DICT)
            log_df = pd.read_csv(os.path.join(path, GlobalVariable.LOG_PATH),
                                 header=None, names=DTR_HEAD.DTR_HEAD, dtype=GloVar.DTR_TYPE_DICT)
            # ========================= TODO: only for 93k
            new_ptmd_list = []
            # First occurrence of each TEST_TXT prefix (text before the first '@');
            # later rows with OPT_FLAG == 0 inherit its limit/scale fields.
            cache_test_ptmd = dict()
            ptmd_df_dict_list = ptmd_df.to_dict(orient='records')
            for each in ptmd_df_dict_list:
                # A float here is pandas' NaN for a missing TEST_TXT cell.
                if isinstance(each[PTMD_HEAD.TEST_TXT], float):
                    Print.Error("有FTR测试项目没有设置TEST_TEXT! 是否是机台生成的STDF有问题?@?")
                    continue
                # str.split always yields at least one element, so `key` is always set
                # (the original `len(...) == 0` branch was unreachable).
                key = each[PTMD_HEAD.TEST_TXT].split("@", 1)[0]
                if key not in cache_test_ptmd:
                    cache_test_ptmd[key] = each
                if each[PTMD_HEAD.OPT_FLAG] == 0:
                    # OPT_FLG 0x0 carries no limit data: copy limits/scales from the
                    # cached row with the same text prefix.
                    temp_each = cache_test_ptmd[key]
                    each[PTMD_HEAD.PARM_FLG] = temp_each[PTMD_HEAD.PARM_FLG]
                    each[PTMD_HEAD.OPT_FLAG] = temp_each[PTMD_HEAD.OPT_FLAG]
                    each[PTMD_HEAD.RES_SCAL] = temp_each[PTMD_HEAD.RES_SCAL]
                    each[PTMD_HEAD.LLM_SCAL] = temp_each[PTMD_HEAD.LLM_SCAL]
                    each[PTMD_HEAD.HLM_SCAL] = temp_each[PTMD_HEAD.HLM_SCAL]
                    each[PTMD_HEAD.LO_LIMIT] = temp_each[PTMD_HEAD.LO_LIMIT]
                    each[PTMD_HEAD.HI_LIMIT] = temp_each[PTMD_HEAD.HI_LIMIT]
                    each[PTMD_HEAD.UNITS] = temp_each[PTMD_HEAD.UNITS]
                new_ptmd_list.append(each)
            ptmd_df = pd.DataFrame(new_ptmd_list)
            # ==================================================
            df_module = DataModule(prr_df=prr_df, dtp_df=dtp_df, ptmd_df=ptmd_df, bin_df=bin_df, log_df=log_df)
            return df_module
        except Exception as err:
            Print.Error(err)
            return None

    @staticmethod
    def save_hdf5(df_module: DataModule, file_path: str) -> bool:
        """
        Write all five DataFrames of *df_module* into a single HDF5 file.

        The first write truncates the file (``mode="w"``); subsequent writes
        append (``mode="r+"``).  ``ptmd_df`` uses ``format="table"``.
        :return: True on success, False on any write error.
        """
        try:
            df_module.prr_df.to_hdf(file_path, "prr_df", mode="w")
            df_module.ptmd_df.to_hdf(file_path, "ptmd_df", mode="r+", format="table")
            df_module.dtp_df.to_hdf(file_path, "dtp_df", mode="r+")
            df_module.bin_df.to_hdf(file_path, "bin_df", mode="r+")
            df_module.log_df.to_hdf(file_path, "log_df", mode="r+")
            return True
        except Exception as err:
            print(err)
            return False

    @staticmethod
    def chuck_save_hdf5(df_module: DataModule, file_path: str) -> bool:
        """
        Save with one HDF5 key per TEST_ID chunk.  Too slow -- do not use.

        :return: True on success, False on any write error.
        """
        try:
            df_module.prr_df.to_hdf(file_path, "prr_df", mode="w")
            df_module.ptmd_df.to_hdf(file_path, "ptmd_df", mode="r+", format="table")
            for test_id, df in df_module.dtp_df.groupby(DTP_HEAD.TEST_ID):
                df.to_hdf(file_path, "dtp_df_{}".format(test_id), mode="r+")
            df_module.bin_df.to_hdf(file_path, "bin_df", mode="r+")
            return True
        except Exception as err:
            print(err)
            return False

    @staticmethod
    def get_yield(prr_df, part_flag, read_fail) -> dict:
        """
        Return simple yield information for the selected subset of parts.

        :param read_fail: truthy -> include failing parts before counting.
        :param part_flag: PART_FLAGS = ('ALL', 'FIRST', 'RETEST', 'FINALLY', "XY_COORD")
        :param prr_df: raw PRR DataFrame.
        :return: dict with QTY / PASS / YIELD keys (see ``get_yield_data``).
        """
        df = ParserData.get_prr_data(prr_df, part_flag, read_fail)
        return ParserData.get_yield_data(df)

    @staticmethod
    def get_prr_data(prr_df, part_flag, read_fail) -> pd.DataFrame:
        """
        Filter PRR rows by pass/fail and by the requested part-selection mode.

        :param prr_df: raw PRR DataFrame.
        :param part_flag: one of ``PartFlags`` (ALL/FIRST/RETEST/FINALLY/XY_COORD/FIRST_XY).
        :param read_fail: falsy -> keep only rows whose FAIL_FLAG is PASS.
        :return: the filtered DataFrame.
        """
        df = prr_df
        if not read_fail:
            df = df[df.FAIL_FLAG == FailFlag.PASS]
        # NOTE(review): rows where PART_FLG lacks the FirstTest bit are treated as
        # first-test parts below -- confirm against PrrPartFlag's bit semantics.
        if part_flag == PartFlags.FIRST:
            df = df[df.PART_FLG & PrrPartFlag.FirstTest != PrrPartFlag.FirstTest]
        if part_flag == PartFlags.RETEST:
            df = df[df.PART_FLG & PrrPartFlag.FirstTest == PrrPartFlag.FirstTest]
        if part_flag == PartFlags.FINALLY:
            # Final result per part: first-test passes plus every retest row.
            first_df = df[df.PART_FLG & PrrPartFlag.FirstTest != PrrPartFlag.FirstTest]
            retest_df = df[df.PART_FLG & PrrPartFlag.FirstTest == PrrPartFlag.FirstTest]
            first_pass_df = first_df[first_df.FAIL_FLAG == FailFlag.PASS]
            df = pd.concat([first_pass_df, retest_df])
        if part_flag == PartFlags.XY_COORD:
            # Keep the LAST touch-down per (X, Y) coordinate.
            if PRR_HEAD.DIE_ID in df:
                df1 = df[[PRR_HEAD.DIE_ID, PRR_HEAD.X_COORD, PRR_HEAD.Y_COORD]].groupby(
                    [PRR_HEAD.X_COORD, PRR_HEAD.Y_COORD]).last()
                df = df[df.DIE_ID.isin(df1.DIE_ID)]
            else:
                df1 = df[[PRR_HEAD.ID, PRR_HEAD.PART_ID, PRR_HEAD.X_COORD, PRR_HEAD.Y_COORD]].groupby(
                    [PRR_HEAD.X_COORD, PRR_HEAD.Y_COORD]).last()
                df = pd.merge(df.reset_index(), df1, on=[PRR_HEAD.ID, PRR_HEAD.PART_ID]).set_index(PRR_HEAD.DIE_ID)
        if part_flag == PartFlags.FIRST_XY:
            # Keep the FIRST touch-down per (X, Y) coordinate.
            if PRR_HEAD.DIE_ID in df:
                df1 = df[[PRR_HEAD.DIE_ID, PRR_HEAD.X_COORD, PRR_HEAD.Y_COORD]].groupby(
                    [PRR_HEAD.X_COORD, PRR_HEAD.Y_COORD]).first()
                df = df[df.DIE_ID.isin(df1.DIE_ID)]
            else:
                df1 = df[[PRR_HEAD.ID, PRR_HEAD.PART_ID, PRR_HEAD.X_COORD, PRR_HEAD.Y_COORD]].groupby(
                    [PRR_HEAD.X_COORD, PRR_HEAD.Y_COORD]).first()
                df = pd.merge(df.reset_index(), df1, on=[PRR_HEAD.ID, PRR_HEAD.PART_ID]).set_index(PRR_HEAD.DIE_ID)
        return df

    @staticmethod
    def get_yield_data(df: pd.DataFrame):
        """
        Count total/pass parts in *df* and format the yield percentage.

        :return: ``{QTY: int, PASS: int, YIELD: "xx.x%"}``; yield is "0.0%"
                 when *df* is empty (avoids division by zero).
        """
        pass_qty = len(df[df.FAIL_FLAG == FailFlag.PASS])
        qty = len(df)
        if qty == 0:
            pass_yield = "0.0%"
        else:
            pass_yield = '{}%'.format(round(pass_qty / qty * 100, 2))
        return {
            STDF_HEAD.QTY: qty,
            STDF_HEAD.PASS: pass_qty,
            STDF_HEAD.YIELD: pass_yield,
        }

    @staticmethod
    def load_prr_df(file_path: str, unit_id=1) -> Union[pd.DataFrame, None]:
        """
        Load only the PRR DataFrame from an HDF5 file and derive DIE_ID.

        In theory any HDF5 produced by ``save_hdf5`` loads without error.
        :param unit_id: file/unit index; offsets PART_ID into a globally
                        unique DIE_ID via ``DIE_ID_ADD``.
        :param file_path: HDF5 file path.
        :return: the PRR DataFrame, or ``None`` if the stored object is not one.
        """
        df = pd.read_hdf(file_path, key="prr_df")
        if not isinstance(df, Df):
            return None
        df[PRR_HEAD.DIE_ID] = df[PRR_HEAD.PART_ID] + unit_id * GlobalVariable.DIE_ID_ADD
        return df

    @staticmethod
    def load_ptmd_df(file_path: str, unit_id=1) -> Union[pd.DataFrame, None]:
        """
        Load only the PTMD DataFrame from an HDF5 file.

        Inserts the unit ``ID`` column and builds the display TEXT as
        ``"<TEST_NUM>:<TEST_TXT>"``.
        :return: the PTMD DataFrame, or ``None`` if the stored object is not one.
        """
        ptmd_df = pd.read_hdf(file_path, key="ptmd_df")
        if not isinstance(ptmd_df, Df):
            return None
        ptmd_df.insert(0, column="ID", value=unit_id)
        ptmd_df[STDF_HEAD.ID] = ptmd_df[STDF_HEAD.ID].astype(np.uint32)
        ptmd_df[PTMD_HEAD.TEXT] = ptmd_df[PTMD_HEAD.TEST_NUM].astype(str) + ":" + ptmd_df[PTMD_HEAD.TEST_TXT]
        return ptmd_df

    @staticmethod
    def _mark_fail_flag(dtp_df: pd.DataFrame) -> pd.DataFrame:
        """Derive the uint8 DTP FAIL_FLG column from the TEST_FLG bitmask."""
        fail_mask = dtp_df.TEST_FLG & DtpTestFlag.TestFailed == DtpTestFlag.TestFailed
        temp_fail = dtp_df[fail_mask].copy()
        temp_pass = dtp_df[~fail_mask].copy()
        temp_pass[DTP_HEAD.FAIL_FLG] = FailFlag.PASS
        temp_fail[DTP_HEAD.FAIL_FLG] = FailFlag.FAIL
        out_df = pd.concat([temp_pass, temp_fail])
        out_df[DTP_HEAD.FAIL_FLG] = out_df[DTP_HEAD.FAIL_FLG].astype(np.uint8)
        return out_df

    @staticmethod
    @Time()
    def load_hdf5_analysis(
            file_path: str, part_flag: int, read_fail: int, unit_id: int, old=True,
            quick: bool = False, sample_num: int = 10000
    ) -> Union[DataModule, None]:
        """
        Load a full ``DataModule`` from HDF5 and prepare it for analysis.

        :param file_path: HDF5 file written by ``save_hdf5``.
        :param part_flag: part-selection mode (see ``get_prr_data``).
        :param read_fail: include failing parts (see ``get_prr_data``).
        :param unit_id: unit index; stamped into the ID column and used to
                        derive globally unique DIE_IDs.
        :param old: legacy path -- also derive the DTP FAIL_FLG column here.
        :param quick: when True, down-sample PRR to at most *sample_num* rows.
        :param sample_num: sampling cap (BUGFIX: default was the float ``1E4``,
                           which ``DataFrame.sample`` rejects).
        :return: the DataModule, processed further by the tree view, or ``None``
                 when data is missing/empty.
        """
        try:
            prr_df = pd.read_hdf(file_path, key="prr_df")
            dtp_df = pd.read_hdf(file_path, key="dtp_df")
            ptmd_df = pd.read_hdf(file_path, key="ptmd_df")
            bin_df = pd.read_hdf(file_path, key="bin_df")
            try:
                # Older files have no log_df key; treat it as empty.
                log_df = pd.read_hdf(file_path, key="log_df")
            except Exception:
                log_df = pd.DataFrame()

        except Exception as err:
            Print.Warning("有数据读取失败/缺失, ID:{}, 是否STDF中没有数据? {}".format(unit_id, err))
            return None
        if not isinstance(prr_df, Df) or not isinstance(dtp_df, Df) \
                or not isinstance(ptmd_df, Df) or not isinstance(bin_df, Df) or not isinstance(log_df, Df):
            raise Exception("ERROR@!!!load_hdf5_analysis")
        if len(dtp_df) == 0:
            # BUGFIX: the placeholder was never filled in -- pass unit_id.
            Print.Warning("有数据读取跳过, ID:{}, STDF中没有数据".format(unit_id))
            return None
        prr_df.insert(0, column=STDF_HEAD.ID, value=unit_id)
        dtp_df.insert(0, column=STDF_HEAD.ID, value=unit_id)
        ptmd_df.insert(0, column=STDF_HEAD.ID, value=unit_id)
        log_df.insert(0, column=STDF_HEAD.ID, value=unit_id)
        prr_df[STDF_HEAD.ID] = prr_df[STDF_HEAD.ID].astype(np.uint32)
        dtp_df[STDF_HEAD.ID] = dtp_df[STDF_HEAD.ID].astype(np.uint32)
        ptmd_df[STDF_HEAD.ID] = ptmd_df[STDF_HEAD.ID].astype(np.uint32)
        log_df[STDF_HEAD.ID] = log_df[STDF_HEAD.ID].astype(np.uint32)

        prr_df[PRR_HEAD.DIE_ID] = prr_df[PRR_HEAD.PART_ID] + unit_id * GlobalVariable.DIE_ID_ADD
        # TODO: decide whether TEXT should include TEST_NUM.
        ptmd_df[PTMD_HEAD.TEXT] = ptmd_df[PTMD_HEAD.TEST_NUM].astype(str) + ":" + ptmd_df[PTMD_HEAD.TEST_TXT]
        prr_df = ParserData.get_prr_data(prr_df, part_flag, read_fail)
        if quick and len(prr_df) > sample_num:
            # int() guards callers that still pass a float cap.
            prr_df = prr_df.sample(int(sample_num))
        dtp_df[DTP_HEAD.DIE_ID] = dtp_df[DTP_HEAD.PART_ID] + unit_id * GlobalVariable.DIE_ID_ADD
        dtp_df = dtp_df[dtp_df.PART_ID.isin(prr_df.PART_ID)]

        # =================== Future Delete
        if old:
            dtp_df = ParserData._mark_fail_flag(dtp_df)

        return DataModule(prr_df=prr_df, dtp_df=dtp_df, ptmd_df=ptmd_df, bin_df=bin_df, log_df=log_df)

    @staticmethod
    @Time()
    def contact_ptmd_list(args: List[pd.DataFrame], subset: str = PTMD_HEAD.TEXT) -> pd.DataFrame:
        """
        Concatenate PTMD DataFrames and assign one NEW_TEST_ID per *subset* group.

        Groups are enumerated in first-seen order (``sort=False``), so every
        row sharing the same test text gets the same NEW_TEST_ID.
        (Optimized: 0.417s -> 0.007s.)
        """
        if len(args) == 1:
            # Single source: old and new test ids are identical.
            ptmd_df = args[0]
            ptmd_df[PTMD_HEAD.NEW_TEST_ID] = ptmd_df[PTMD_HEAD.TEST_ID]
            return ptmd_df
        ptmd_df = pd.concat(args)
        new_ptmd = []
        for new_test_id, (_, df) in enumerate(ptmd_df.groupby(subset, sort=False)):
            df[PTMD_HEAD.NEW_TEST_ID] = new_test_id
            new_ptmd.append(df)
        ptmd_df = pd.concat(new_ptmd)
        ptmd_df[PTMD_HEAD.NEW_TEST_ID] = ptmd_df[PTMD_HEAD.NEW_TEST_ID].astype(np.int32)
        ptmd_df[STDF_HEAD.ID] = ptmd_df[STDF_HEAD.ID].astype(np.uint32)
        return ptmd_df

    @staticmethod
    @Time()
    def contact_ptmd_list_drop(ptmd_df: pd.DataFrame, subset: str = PTMD_HEAD.TEXT) -> pd.DataFrame:
        """
        Drop duplicate PTMD rows, keeping the last one per *subset* value.

        (Optimized: 0.417s -> 0.007s.)
        subset: PTMD_HEAD.TEST_NUM, PTMD_HEAD.TEST_TXT, PTMD_HEAD.TEXT
        """
        return ptmd_df.drop_duplicates(subset=[subset], keep="last")

    @staticmethod
    @Time()
    def contact_data_module(args: List[DataModule]):
        """
        Merge several DataModules, rebuilding TEST_IDs grouped by TEXT.

        TODO: ID is also the bridge to the Summary; this API was retired in 2023.
        :param args: modules to merge; a single element is returned unchanged.
        :return: the merged DataModule (without log_df).
        """
        if len(args) == 1:
            return args[0]
        prr_df_list: list = list()
        dtp_df_list: list = list()
        ptmd_df_list: list = list()
        bin_df_list: list = list()
        for data_module in args:
            prr_df_list.append(data_module.prr_df)
            dtp_df_list.append(data_module.dtp_df)
            ptmd_df_list.append(data_module.ptmd_df)
            bin_df_list.append(data_module.bin_df)
        prr_df = pd.concat(prr_df_list)
        dtp_df = pd.concat(dtp_df_list)
        ptmd_df = pd.concat(ptmd_df_list)
        bin_df = pd.concat(bin_df_list)
        if not bin_df.empty:
            bin_df.drop_duplicates(subset=[BIN_HEAD.BIN_TYPE, BIN_HEAD.BIN_NUM], inplace=True, keep="last")

        new_test_id = 0
        ptmd_dict = {}  # fresh PTMD data, no longer bound to the unit ID
        new_dtps = []
        # Index DTP chunks by "<ID>-<TEST_ID>" for O(1) lookup below.
        dtp_dict = dict()
        for (_id, _test_id), _dtp_df in dtp_df.groupby([STDF_HEAD.ID, DTP_HEAD.TEST_ID], sort=False):
            key = "{}-{}".format(_id, _test_id)
            dtp_dict[key] = _dtp_df

        for text, df in ptmd_df.groupby(PTMD_HEAD.TEXT, sort=False):
            new_test_id += 1
            for row in df.itertuples():  # type:PtmdModule
                # 1. look up the chunk by its old (ID, TEST_ID)  2. stamp the new TEST_ID
                key = "{}-{}".format(row.ID, row.TEST_ID)
                _dtp_df = dtp_dict[key]
                _dtp_df[DTP_HEAD.TEST_ID] = new_test_id
                new_dtps.append(_dtp_df)
                ptmd_dict[new_test_id] = row

        ptmd_df = pd.DataFrame(ptmd_dict.values())
        for k, v in GloVar.PTMD_TYPE_DICT.items():
            ptmd_df[k] = ptmd_df[k].astype(v)
        # list() -- assigning a live dict view as a column is fragile across pandas versions.
        ptmd_df[PTMD_HEAD.TEST_ID] = list(ptmd_dict.keys())
        dtp_df = pd.concat(new_dtps)
        return DataModule(prr_df=prr_df, dtp_df=dtp_df, ptmd_df=ptmd_df, bin_df=bin_df)

    @staticmethod
    @Time()
    def new_contact_data_module(args: List[DataModule], drop_text: str, skip_ftr: bool = False) -> DataModule:
        """
        Merge several DataModules using NEW_TEST_ID instead of rewriting READ_ID.

        :param args: modules to merge.
        :param drop_text: column used as the TEST_ID de-duplication model.
        :param skip_ftr: drop FTR rows from the merged PTMD data.
        Optimization: leaving READ_ID untouched saves most of the time
        -> 2.487s => (0.199s + 0.007s).
        """
        prr_df_list: list = list()
        dtp_df_list: list = list()
        ptmd_df_list: list = list()
        bin_df_list: list = list()
        log_df_list: list = list()
        for data_module in args:
            prr_df_list.append(data_module.prr_df)
            dtp_df_list.append(data_module.dtp_df)
            ptmd_df_list.append(data_module.ptmd_df)
            bin_df_list.append(data_module.bin_df)
            log_df_list.append(data_module.log_df)
        prr_df = pd.concat(prr_df_list)
        dtp_df = pd.concat(dtp_df_list)
        log_df = pd.concat(log_df_list)
        # Keeps the old TEST_ID and adds a NEW_TEST_ID per drop_text group.
        ptmd_df = ParserData.contact_ptmd_list(ptmd_df_list, drop_text)
        if skip_ftr:
            ptmd_df = ptmd_df[ptmd_df[PTMD_HEAD.DATAT_TYPE] != DatatType.FTR]
        bin_df = pd.concat(bin_df_list)
        if not bin_df.empty:
            bin_df.drop_duplicates(subset=[BIN_HEAD.BIN_TYPE, BIN_HEAD.BIN_NUM], inplace=True, keep="last")
        prr_df.set_index([PRR_HEAD.DIE_ID], inplace=True)
        prr_df[PRR_HEAD.DA_GROUP] = "*"
        prr_df[PRR_HEAD.FAIL_TEST_ID] = PrrFailTestId.Pass  # TODO: added 2023-03-11, like FailTestNum; -1 means PASS
        # Map every (ID, TEST_ID) pair onto its NEW_TEST_ID.
        ptmd_df_new_test_id = ptmd_df[[PTMD_HEAD.ID, PTMD_HEAD.TEST_ID, PTMD_HEAD.NEW_TEST_ID]]
        dtp_df = pd.merge(dtp_df, ptmd_df_new_test_id, on=[PTMD_HEAD.ID, PTMD_HEAD.TEST_ID])

        dtp_df = ParserData._mark_fail_flag(dtp_df)

        return DataModule(prr_df=prr_df, dtp_df=dtp_df, ptmd_df=ptmd_df, bin_df=bin_df,
                          ptmd_df_limit=ParserData.contact_ptmd_list_drop(ptmd_df, drop_text), log_df=log_df)
